{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Composable Stable diffusion\n",
    "\n",
    "[Composable Stable Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/) proposes conjunction and negation (negative prompts) operators for compositional generation with conditional diffusion models. This script was contributed by [MarkRich](https://github.com/MarkRich) and the notebook by [Parag Ekbote](https://github.com/ParagEkbote)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pip install torch numpy torchvision diffusers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3dadcf1262e0492cafe9556f62ba3a9f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "composable_stable_diffusion.py:   0%|          | 0.00/27.6k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "279a467d562041ec935edacbf177caba",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "composing ['mystical trees', 'A magical pond', 'dark']...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3717298308004b648b65d6c1b1e02dbe",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Image saved successfully!\n"
     ]
    }
   ],
   "source": [
    "import torch as th\n",
    "import numpy as np\n",
    "import torchvision.utils as tvu\n",
    "from diffusers import DiffusionPipeline\n",
    "import argparse\n",
    "import sys\n",
    "\n",
    "# Simulate passing arguments explicitly (bypassing Jupyter's arguments)\n",
    "sys.argv = [\n",
    "    \"ipykernel_launcher.py\", \n",
    "    \"--prompt\", \"mystical trees | A magical pond | dark\",\n",
    "    \"--steps\", \"50\",\n",
    "    \"--scale\", \"7.5\",\n",
    "    \"--weights\", \"7.5 | 7.5 | -7.5\",\n",
    "    \"--seed\", \"2\",\n",
    "    \"--model_path\", \"CompVis/stable-diffusion-v1-4\",\n",
    "    \"--num_images\", \"1\"\n",
    "]\n",
    "\n",
    "parser = argparse.ArgumentParser()\n",
    "parser.add_argument(\"--prompt\", type=str, default=\"mystical trees | A magical pond | dark\",\n",
    "                    help=\"use '|' as the delimiter to compose separate sentences.\")\n",
    "parser.add_argument(\"--steps\", type=int, default=50)\n",
    "parser.add_argument(\"--scale\", type=float, default=7.5)\n",
    "parser.add_argument(\"--weights\", type=str, default=\"7.5 | 7.5 | -7.5\")\n",
    "parser.add_argument(\"--seed\", type=int, default=2)\n",
    "parser.add_argument(\"--model_path\", type=str, default=\"CompVis/stable-diffusion-v1-4\")\n",
    "parser.add_argument(\"--num_images\", type=int, default=1)\n",
    "args = parser.parse_args()\n",
    "\n",
    "# CUDA Setup\n",
    "has_cuda = th.cuda.is_available()\n",
    "device = th.device('cpu' if not has_cuda else 'cuda')\n",
    "\n",
    "# Assign parameters\n",
    "prompt = args.prompt\n",
    "scale = args.scale\n",
    "steps = args.steps\n",
    "\n",
    "# Load pipeline\n",
    "pipe = DiffusionPipeline.from_pretrained(\n",
    "    args.model_path,\n",
    "    custom_pipeline=\"composable_stable_diffusion\",\n",
    ").to(device)\n",
    "\n",
    "# Disable safety checker (if intentional for internal use)\n",
    "pipe.safety_checker = None\n",
    "\n",
    "# Generate images\n",
    "images = []\n",
    "generator = th.Generator(\"cuda\").manual_seed(args.seed)\n",
    "for i in range(args.num_images):\n",
    "    image = pipe(prompt, guidance_scale=scale, num_inference_steps=steps,\n",
    "                 weights=args.weights, generator=generator).images[0]\n",
    "    images.append(th.from_numpy(np.array(image)).permute(2, 0, 1) / 255.)\n",
    "\n",
    "# Create and save image grid\n",
    "grid = tvu.make_grid(th.stack(images, dim=0), nrow=4, padding=0)\n",
    "tvu.save_image(grid, f'{prompt}_{args.weights}.png')\n",
    "\n",
    "print(\"Image saved successfully!\")\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
